#!/usr/bin/env python

'''Tests the driver cursor API'''

import os, random, re, socket, subprocess, sys, tempfile, unittest

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, "common"))
import utils

# Driver module and the database/table names used by the tests below
r = utils.import_python_driver()
dbName, tableName = utils.get_test_db_table()

# Executable path: first command-line argument when one is given, otherwise
# whatever utils.find_rethinkdb_executable() returns
if len(sys.argv) > 1:
    rethinkdb_exe = sys.argv[1]
else:
    rethinkdb_exe = utils.find_rethinkdb_executable()

# --

server = None

# Address of the server under test; both parts can be overridden through the
# RDB_SERVER_HOST / RDB_DRIVER_PORT environment variables
serverHost = os.getenv('RDB_SERVER_HOST', 'localhost')
serverPort = int(os.getenv('RDB_DRIVER_PORT', 28015))

# -- tests

class TestRangeCursor(unittest.TestCase):
    '''Cursor behaviour on `r.range()` streams, including reads after a close.'''

    def setUp(self):
        self.conn = r.connect(host=serverHost, port=serverPort)

        # `assertRaisesRegexp` does not exist on Python 2.6 and was removed
        # again in Python 3.12. Prefer unittest's own `assertRaisesRegex`
        # when it is available and only hand-roll the check as a last resort.
        if not hasattr(self, 'assertRaisesRegexp'):
            self.assertRaisesRegexp = getattr(
                self, 'assertRaisesRegex', self._assert_raises_regexp_fallback)

    @staticmethod
    def _assert_raises_regexp_fallback(exception, regexp, function, *args, **kwds):
        # Minimal stand-in for unittest's assertRaisesRegexp.
        #
        # Calls `function(*args, **kwds)` and requires it to raise `exception`
        # with a message that `regexp` matches (search semantics, as in
        # unittest). Exceptions of any other type propagate unchanged;
        # AssertionError is raised when nothing is raised or the message does
        # not match.
        try:
            function(*args, **kwds)
        except exception as e:
            if not re.search(regexp, str(e)):
                raise AssertionError('"%s" does not match "%s"' % (str(regexp), str(e)))
            return
        expected_name = getattr(exception, '__name__', str(exception))
        raise AssertionError('%s not raised for: %s' % (expected_name, str(function)))

    def test_cursor_after_connection_close(self):
        cursor = r.range().run(self.conn)
        self.conn.close()

        # The rows are counted in a one-element list so the closure can update
        # it (no `nonlocal` on Python 2) and the count survives the exception
        # that ends the iteration. assertRaisesRegexp itself returns None, so
        # its return value cannot carry the count.
        rows_seen = [0]

        def read_cursor(cursor):
            for _ in cursor:
                rows_seen[0] += 1

        self.assertRaisesRegexp(r.ReqlRuntimeError, "Connection is closed.", read_cursor, cursor)
        self.assertNotEqual(rows_seen[0], 0, "Did not get any cursor results")

    def test_cursor_after_cursor_close(self):
        # A closed cursor is still expected to hand out some rows
        cursor = r.range().run(self.conn)
        cursor.close()
        count = 0
        for _ in cursor:
            count += 1
        self.assertNotEqual(count, 0, "Did not get any cursor results")

    def test_cursor_close_in_each(self):
        # Closing the cursor in the middle of the loop must not cut the
        # iteration off right away
        cursor = r.range().run(self.conn)
        count = 0
        for _ in cursor:
            count += 1
            if count == 2:
                cursor.close()
        self.assertTrue(count > 2, "Did not get enough cursor results")

    def test_cursor_success(self):
        range_size = 10000
        cursor = r.range().limit(range_size).run(self.conn)
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(count, range_size, "Expected %d results on the cursor, but got %d" % (range_size, count))

    def test_cursor_double_each(self):
        range_size = 10000
        cursor = r.range().limit(range_size).run(self.conn)
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(count, range_size, "Expected %d results on the cursor, but got %d" % (range_size, count))

        # The cursor is exhausted, so a second pass must not add anything
        for _ in cursor:
            count += 1
        self.assertEqual(count, range_size, "Expected no results on the second iteration of the cursor, but got %d" % (count - range_size))

class TestCursor(unittest.TestCase):
    '''Reads a freshly populated table back through a cursor.'''

    # NOTE: setUp assigns `self.conn`, which binds on the instance and leaves
    # this class attribute at None. unittest creates one instance per test
    # method, so every test opens its own connection and rebuilds the table.
    conn = None
    num_rows = random.randint(1111, 2222)

    def setupTable(self):
        # Drop and re-create the test table, then fill it with `num_rows`
        # empty documents.

        # - ensure a clean table

        if dbName not in r.db_list().run(self.conn):
            r.db_create(dbName).run(self.conn)

        if tableName in r.db(dbName).table_list().run(self.conn):
            r.db(dbName).table_drop(tableName).run(self.conn)
        r.db(dbName).table_create(tableName).run(self.conn)

        # - populate the table with the requested data

        res = r.db(dbName).table(tableName).insert(r.range(self.num_rows).map(lambda x: {})).run(self.conn)
        # assertEqual instead of a bare `assert`, so the check is not stripped
        # when Python runs with -O
        self.assertEqual(res['inserted'], self.num_rows)

    def setUp(self):

        if self.conn is None:
            self.conn = r.connect(host=serverHost, port=serverPort)
            self.setupTable()

        # - supply a cursor to the full table

        self.cur = r.db(dbName).table(tableName).run(self.conn)

    def test_type(self):
        self.assertTrue(isinstance(self.cur, r.Cursor))

    def test_count(self):
        # The cursor must yield exactly as many rows as were inserted
        self.assertEqual(sum(1 for _ in self.cur), self.num_rows)

    def test_close(self):
        # This exercises a code path at the root of #650
        self.conn.close()

class TestCursorWait(unittest.TestCase):
    '''Checks how the `wait` argument of `cursor.next()` affects timeouts.'''

    # NOTE: setUp assigns `self.conn` on the instance, so this class attribute
    # stays None and each test method (a separate instance) connects anew.
    conn = None
    num_cursors = 3

    def setUp(self):
        if self.conn is None:
            self.conn = r.connect(host=serverHost, port=serverPort)

    def do_test(self, wait_time):
        # Read from `num_cursors` deliberately slow cursors in random bursts
        # until 4000 rows have been received in total.
        #
        # `wait_time` is passed to `next(wait=...)`; None means call `next()`
        # without the argument so the driver default applies.
        #
        # Returns a `(rows_read, timeouts)` tuple summed over all cursors.
        cursors = [ ]
        cursor_counts = [0] * self.num_cursors
        cursor_timeouts = [0] * self.num_cursors

        def get_next(cursor_index):
            try:
                if wait_time is None: # Special case to use the default
                    cursors[cursor_index].next()
                else:
                    cursors[cursor_index].next(wait=wait_time)
                cursor_counts[cursor_index] += 1
            except r.ReqlTimeoutError:
                cursor_timeouts[cursor_index] += 1

        try:
            for _ in range(self.num_cursors):
                # The JS function busy-waits about 2ms per row to slow the stream down
                cursors.append(r.range().map(r.js("(function (row) {" +
                                                      "end = new Date(new Date().getTime() + 2);" +
                                                      "while (new Date() < end) { }" +
                                                      "return row;" +
                                                  "})")).run(self.conn, max_batch_rows=500))

            # We need to get ahead of pre-fetching for this to get the error we want
            while sum(cursor_counts) < 4000:
                for cursor_index in range(self.num_cursors):
                    for _ in range(random.randint(0, 10)):
                        get_next(cursor_index)
        finally:
            # Close every cursor that was opened, even if reading failed
            for cursor in cursors:
                cursor.close()

        return (sum(cursor_counts), sum(cursor_timeouts))

    def test_false_wait(self):
        reads, timeouts = self.do_test(False)
        self.assertNotEqual(timeouts, 0, "Did not get timeouts using zero (false) wait.")

    def test_zero_wait(self):
        reads, timeouts = self.do_test(0)
        self.assertNotEqual(timeouts, 0, "Did not get timeouts using zero wait.")

    def test_short_wait(self):
        reads, timeouts = self.do_test(0.0001)
        self.assertNotEqual(timeouts, 0, "Did not get timeouts using short (100 microsecond) wait.")

    def test_long_wait(self):
        reads, timeouts = self.do_test(10)
        self.assertEqual(timeouts, 0, "Got timeouts using long (10 second) wait.")

    def test_infinite_wait(self):
        reads, timeouts = self.do_test(True)
        self.assertEqual(timeouts, 0, "Got timeouts using infinite wait.")

    def test_default_wait(self):
        reads, timeouts = self.do_test(None)
        self.assertEqual(timeouts, 0, "Got timeouts using default (infinite) wait.")

class TestChangefeedWait(unittest.TestCase):
    '''Checks the `wait` argument of `next()` on a changefeed cursor.'''

    def test_wait(self):
        conn = r.connect(host=serverHost, port=serverPort)
        try:
            # - ensure the table exists (existing rows are left untouched)
            if dbName not in r.db_list().run(conn):
                r.db_create(dbName).run(conn)
            if tableName not in r.db(dbName).table_list().run(conn):
                r.db(dbName).table_create(tableName).run(conn)

            changes = r.db(dbName).table(tableName).changes().run(conn)

            # This test has not written anything since the feed was opened,
            # so each bounded wait is expected to time out
            self.assertRaises(r.ReqlTimeoutError, changes.next, wait=0)
            self.assertRaises(r.ReqlTimeoutError, changes.next, wait=0.2)
            self.assertRaises(r.ReqlTimeoutError, changes.next, wait=1)
            self.assertRaises(r.ReqlTimeoutError, changes.next, wait=5)

            res = r.db(dbName).table(tableName).insert({}).run(conn)
            self.assertEqual(res['inserted'], 1)

            # After the insert, the next change must arrive within the wait;
            # a ReqlTimeoutError here fails the test
            changes.next(wait=1)
        finally:
            # Do not leak the connection, whatever the outcome of the test
            conn.close(noreply_wait=False)

if __name__ == '__main__':
    print("Testing cursor for %d rows" % TestCursor.num_rows)

    # Build one suite out of all test cases, in a fixed order
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_case in (TestCursor, TestRangeCursor, TestCursorWait, TestChangefeedWait):
        suite.addTest(loader.loadTestsFromTestCase(test_case))

    # run() returns a TestResult object, which is always truthy; the outcome
    # has to be read from wasSuccessful() for the exit status to reflect
    # failures and errors.
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    if not result.wasSuccessful():
        sys.exit(1)
